library(tidyverse)
library(knitr)
library(plotly) ; library(viridis) ; library(gridExtra) ; library(RColorBrewer) ; library(ggpubr)
library(reshape2)
library(mgcv) # GAM
library(caret) ; library(DMwR) ; library(ROCR) ; library(car) ; library(MLmetrics)
library(knitr) ; library(kableExtra)
library(ROCR)
library(expss)

# Return the SFARI palette colours at positions r
#
# r: integer vector of positions (1-10) into the fixed SFARI colour palette
#    (scores 1-6, then neutral greys used for Neuronal/Others groups)
# Returns a character vector of colour codes, same length as r
#
# Fix: the original only assigned the subset to a local variable, so the value
# was returned invisibly; return it directly instead
SFARI_colour_hue = function(r) {
  pal = c('#FF7631','#FFB100','#E8E328','#8CC83F','#62CCA6','#59B9C9','#b3b3b3','#808080','gray','#d9d9d9')
  pal[r]
}
# Gandal dataset: loads datExpr, datGenes, DE_info, datMeta, dds (among others)
load('./../Data/preprocessed_data.RData')
datExpr = datExpr %>% data.frame
rownames(datExpr) = datGenes$ensembl_gene_id
DE_info = DE_info %>% data.frame
rownames(DE_info) = datGenes$ensembl_gene_id
datMeta = datMeta %>% mutate(ID = title)


# Ridge Regression output: loads 'predictions' and 'dataset' used below
load('./../Data/Ridge_model.RData')

# SFARI Genes: keep one row per gene, dropping missing IDs
SFARI_genes = read_csv('./../../../SFARI/Data/SFARI_genes_01-03-2020_w_ensembl_IDs.csv')
SFARI_genes = SFARI_genes[!duplicated(SFARI_genes$ID) & !is.na(SFARI_genes$ID),]


# GO Neuronal annotations: regex 'neuron' in GO functional annotations and label the genes that make a match as neuronal
GO_annotations = read.csv('./../Data/genes_GO_annotations.csv')
GO_neuronal = GO_annotations %>% filter(grepl('neuron', go_term)) %>% 
              mutate('ID'=as.character(ensembl_gene_id)) %>% 
              dplyr::select(-ensembl_gene_id) %>% distinct(ID) %>%
              mutate('Neuronal'=1)

# Add all this info to predictions: gene.score is the SFARI score when present,
# otherwise 'Neuronal' if the gene has a neuronal GO annotation, else 'Others'
biased_predictions = predictions %>% left_join(SFARI_genes %>% dplyr::select(ID, `gene-score`), by = 'ID') %>%
                     mutate(gene.score = ifelse(is.na(`gene-score`), 
                                                ifelse(ID %in% GO_neuronal$ID, 'Neuronal', 'Others'), 
                                                `gene-score`)) %>%
                      dplyr::select(-`gene-score`)

# Keep the module assignment from the selected clustering algorithm only
clustering_selected = 'DynamicHybrid'
clusterings = read_csv('./../Data/clusters.csv')
clusterings$Module = clusterings[,clustering_selected] %>% data.frame %>% unlist %>% unname
assigned_module = clusterings %>% dplyr::select(ID, Module)


# Drop intermediates no longer needed downstream
rm(rownames_dataset, GO_annotations, datGenes, dds, clustering_selected,
   clusterings)


Weighting Technique Implementation


Introduction


Problem


As it can be seen in 20_04_08_Ridge.html, there is a relation between the Probability the model assigns to a gene and the gene’s mean level of expression. This is a problem because we had previously discovered a bias in the SFARI scores related to mean level of expression (Preprocessing/Gandal/AllRegions/RMarkdowns/20_04_03_SFARI_genes.html), which means that this could be a confounding factor in our model and the reason why it seems to perform well, so we need to remove this bias to recover the true biological signal that is mixed with it and improve the quality of our model.


General idea


train model with equal weights for all samples

for l in loop:
  calculate bias
  correct weights to reduce bias
  retrain model
  
Return last model
  


Pseudocode


Parameters:

  • eta: Learning rate

  • T: Number of loops

  • D: Training data

  • H: Classification model

  • c: bias constraint

  • lambda: scaling factor for the weights

  • \(w_i\) with \(i=1,...,N\): Weights assigned to each sample


Pseudocode:

lambda = 0
w = [1, ..., 1]
c = std(meanExpr(D))

h  = train classifier H with lambda and w

for t in 1,...,T do
  bias = <h(x), c(x)>
  update lambda to lambda - eta*bias
  update weights_hat to exp(lambda*mean(c))
  update weights to w_hat/(1+w_hat) if y_i=1, 1/(1+w_hat) if y_i=0
  update h with new weights
  
Return h



Remove Bias


Demographic Parity as a measure of Bias


Using Demographic Parity as a measure of bias: A fair classifier h should make positive predictions on each segment \(G\) of the population at the same rate as in the population as a whole

This definition is for discrete segments of the population. Since our bias is found across all the population but in different measures depending on the mean level of expression of the gene, we have to adapt this definition to a continuous bias scenario

Demographic Parity for our problem: A fair classifier h should make positive predictions on genes with a certain mean level of expression at the same rate as in all of the genes in the dataset


Bias Metric


The original formula for the Demographic Parity bias is

  • \(c(x,0) = 0\) when the prediction is negative

  • \(c(x,1) = \frac{g(x)}{Z_G}-1\) when the prediction is positive. Where \(g(x)\) is the Kronecker delta to indicate if the sample belongs to the protected group and \(Z_G\) is the proportion of the population that belongs to the group we want to protect against bias


Using these definitions in our problem:

\(g(x):\) Since all our samples belong to the protected group, this would always be 1

\(Z_G:\) Since all of our samples belong to the protected group, this would also always be 1

So our measure of bias \(c(x,1) = \frac{1}{1}-1 = 0\) for all samples. This doesn’t work, so we need to adapt it to our continuous case


Adaptation of the bias metric


We can use \(c(x,1) = std(meanExpr(x))\) as the constraint function, this way, when we calculate the bias of the dataset:

\(h(x)\cdot c(x)\) will only be zero if the positive samples are balanced around the mean expression, and the sign of the bias will indicate the direction of the bias


Calculating the Weights


Notes:

  • This model is only going to be used to obtain the optimal weights for the observations. The weights obtained in this model are going to be the ones used to train the final model afterwards

  • In the original model we would use oversampling as part of the training function, but here, since we are adding weights to the samples, the oversampling will have to be performed separately (otherwise the weights don’t match the number of observations), so it was moved inside the create_train_test_sets function

### DEFINE FUNCTIONS

# Split 'dataset' into train/test sets stratified by SFARI gene score, then
# oversample the positive class in the train set with SMOTE
#
# p:    proportion of samples assigned to the train set
# seed: RNG seed for the partition
# Returns a list with the (oversampled) train set, the untouched test set, and
# the gene IDs of the train-set rows (kept separately because SMOTE drops
# rownames)
# NOTE(review): relies on the globals 'dataset' and 'biased_predictions'
create_train_test_sets = function(p, seed){
  
  # Get SFARI Score of all the samples so our train and test sets are balanced for each score
  sample_scores = dataset %>% mutate(ID = rownames(.)) %>% dplyr::select(ID) %>% 
                  left_join(biased_predictions %>% dplyr::select(ID, gene.score), by = 'ID') %>% 
                  mutate(gene.score = ifelse(is.na(gene.score), 'None', gene.score))

  # Stratified partition: each gene.score group is split in proportion p
  set.seed(seed)
  train_idx = createDataPartition(sample_scores$gene.score, p = p, list = FALSE)
  train_set = dataset[train_idx,]
  test_set = dataset[-train_idx,]
  
  # Modify SFARI label in train set, save gene IDS (bc we lose them with SMOTE) and perform oversampling using SMOTE
  train_set = train_set %>% mutate(SFARI = ifelse(SFARI == TRUE, 'SFARI', 'not_SFARI') %>% as.factor,
                                   ID = rownames(.) %>% as.factor) %>% SMOTE(form = SFARI ~ . - ID)
  train_set_IDs = train_set %>% pull(ID)
  
  return(list('train_set' = train_set %>% dplyr::select(-ID), 'test_set' = test_set, 
              'train_set_IDs' = train_set_IDs))
}

# Iteratively reweight the training samples to remove the mean-expression bias
# from the ridge classifier (glmnet with alpha = 0, tuned by repeated CV)
#
# p:     proportion of samples in the train set (passed to create_train_test_sets)
# seed:  RNG seed for the split and the CV folds
# Loops: number of bias-correction iterations
# Returns a list with the learned weight parameter 'lambda' and the
# per-iteration tracking vectors 'bias_vec' (bias) and 'acc_vec' (train accuracy)
# NOTE(review): relies on the global 'datExpr' for mean expression values
run_weights_model = function(p, seed, Loops){
  
  # CREATE TRAIN AND TEST SETS
  train_test_sets = create_train_test_sets(p, seed)
  train_set = train_test_sets[['train_set']]
  test_set = train_test_sets[['test_set']]
  train_set_IDs = train_test_sets[['train_set_IDs']]
  
  
  # SET INITIAL PARAMETERS
  
  # General parameters
  set.seed(seed)
  lambda_seq = 10^seq(1, -4, by = -.1) # glmnet regularisation grid
  k_fold = 10
  cv_repeats = 5
  trControl = trainControl(method = 'repeatedcv', number = k_fold, repeats = cv_repeats, verboseIter = FALSE, 
                           classProbs = TRUE, savePredictions = 'final', summaryFunction = twoClassSummary)
  # Bias correction parameters
  eta = 0.5 # learning rate for the lambda update
  lambda = 0 # weight-scaling parameter, updated each loop
  w = rep(1, nrow(train_set)) # start with equal weights for all samples
  
  
  # TRAIN MODEL
  h = train(SFARI ~., data = train_set, method = 'glmnet', trControl = trControl, metric = 'ROC',
            tuneGrid = expand.grid(alpha = 0, lambda = lambda_seq))
  
  
  # CORRECT BIAS
  
  # Mean Expression info (standardised; aligned to the SMOTE'd train set by ID)
  mean_expr = data.frame('ID' = train_set_IDs) %>% 
              left_join(data.frame('ID' = rownames(datExpr), 'meanExpr' = rowMeans(datExpr)), by = 'ID') %>%
              mutate('meanExpr_std' = (meanExpr-mean(meanExpr))/sd(meanExpr))
  
  # Track behaviour of plot
  bias_vec = c()
  acc_vec = c()
  
  for(l in 1:Loops){
    
    # Calculate bias for positive predicted samples: mean standardised
    # expression of genes currently predicted SFARI (zero when the positives
    # are balanced around the overall mean expression)
    bias = mean(mean_expr$meanExpr_std[predict(h, train_set)=='SFARI'])
    
    # Update weights: w = 1/(1+w_hat) for negatives, w_hat/(1+w_hat) for positives
    lambda = lambda - eta*bias
    w_hat = exp(lambda*mean_expr$meanExpr_std)
    w = 1/(1+w_hat)
    w[train_set$SFARI=='SFARI'] = w[train_set$SFARI=='SFARI']*w_hat[train_set$SFARI=='SFARI']
    
    # Update tracking vars
    bias_vec = c(bias_vec, bias)
    acc_vec = c(acc_vec, mean(predict(h, train_set) == train_set$SFARI))
    
    # Update h: retrain with the corrected sample weights
    h = train(SFARI ~., data = train_set, method = 'glmnet', weights = w, trControl = trControl, 
              metric = 'ROC', tuneGrid = expand.grid(alpha = 0, lambda = lambda_seq))
  }

  
  return(list('lambda' = lambda, 'bias_vec' = bias_vec, 'acc_vec' = acc_vec))
}


### RUN MODEL TO FIND OPTIMAL WEIGHTS

# Parameters
p = 0.75     # proportion of samples in the train set
seed = 123   # RNG seed for the split and model fitting
Loops = 50   # number of bias-correction iterations

# Run model
model_output = run_weights_model(p, seed, Loops)

# Extract metrics
lambda = model_output[['lambda']]
bias_vec = model_output[['bias_vec']]
acc_vec = model_output[['acc_vec']]

# Fix: the original called rm(p, seeds, ...) but the variable defined above is
# 'seed', not 'seeds' — that triggered a "object not found" warning and left
# 'seed' behind in the workspace
rm(p, seed, Loops, run_weights_model)

The optimal value of \(\lambda\) is -1.4970506


The bias decreases until it oscillates around zero and the accuracy is not affected much

# Reshape the per-iteration bias and accuracy into long format so both series
# can be drawn as coloured lines in a single plot
plot_info = melt(data.frame('iter' = 1:length(bias_vec),
                            'bias' = bias_vec,
                            'accuracy' = acc_vec),
                 id.vars = 'iter')

ggplot(plot_info, aes(x = iter, y = value, color = variable)) +
  geom_line() +
  xlab('Iteration') +
  theme_minimal()

Since the bias increases the probability of being classified as 1 for genes with higher levels of expression, as the level of expression of a gene increases, the algorithm:

  • Increases the weight of genes with a negative label

  • Decreases the weight of genes with a positive label

# Reconstruct the final weights for every gene that appeared in a test set,
# using the lambda learned above
mean_expr = data.frame('ID' = rownames(datExpr), 'meanExpr' = rowMeans(datExpr)) %>%
            left_join(predictions, by = 'ID') %>% filter(n>0) %>%
            mutate('meanExpr_std' = (meanExpr-mean(meanExpr))/sd(meanExpr))

w_hat = exp(lambda*mean_expr$meanExpr_std) # inverse to mean expression
w0 = 1/(1+w_hat) # proportional to mean expression
w = 1/(1+w_hat)
w[mean_expr$SFARI %>% as.logical] = w[mean_expr$SFARI %>% as.logical]*w_hat[mean_expr$SFARI %>% as.logical] # inverse to mean expression for positives, proportional for negatives
plot_data = data.frame('meanExpr' = mean_expr$meanExpr, 'w_hat' = w_hat, 'w0' = w0, 'w' = w, 
                       'SFARI' = mean_expr$SFARI, 'pred' = mean_expr$pred)

# Final weight vs mean expression, coloured by SFARI label
plot_data %>% ggplot(aes(meanExpr, w, color = SFARI)) + geom_point(alpha = 0.3) + 
              xlab('Mean Expression') + ylab('Weight') +
              ggtitle('Weights of the final model') + ylim(c(0,1)) + theme_minimal()

rm(mean_expr, w_hat, w0, w)


Running the final model

# Train the final bias-corrected classifier with fixed sample weights derived
# from the learned lambda, then evaluate it on the held-out test set
#
# p:      proportion of samples in the train set
# seed:   RNG seed for the split and the CV folds
# lambda: weight-scaling parameter learned by run_weights_model
# Returns a list with performance metrics (acc, prec, rec, F1, AUC), the test
# set predictions ('preds') and the fitted coefficients ('coefs')
# NOTE(review): relies on the global 'datExpr' for mean expression values
run_final_model = function(p, seed, lambda){
  
  # CREATE TRAIN AND TEST SETS
  train_test_sets = create_train_test_sets(p, seed)
  train_set = train_test_sets[['train_set']]
  test_set = train_test_sets[['test_set']]
  train_set_IDs = train_test_sets[['train_set_IDs']]
  
  
  # SET INITIAL PARAMETERS
  
  # General parameters
  set.seed(seed)
  lambda_seq = 10^seq(1, -4, by = -.1) # glmnet regularisation grid
  k_fold = 10
  cv_repeats = 5
  trControl = trainControl(method = 'repeatedcv', number = k_fold, repeats = cv_repeats, verboseIter = FALSE, 
                           classProbs = TRUE, savePredictions = 'final', summaryFunction = twoClassSummary)
  
  # Bias correcting parameters: fixed weights computed once from lambda
  # (same formula as the update step in run_weights_model)
  mean_expr = data.frame('ID' = train_set_IDs) %>% 
              left_join(data.frame('ID' = rownames(datExpr), 'meanExpr' = rowMeans(datExpr)), by = 'ID') %>%
              mutate('meanExpr_std' = (meanExpr-mean(meanExpr))/sd(meanExpr))
  w_hat = exp(lambda*mean_expr$meanExpr_std)
  w = 1/(1+w_hat)
  w[train_set$SFARI=='SFARI'] = w[train_set$SFARI=='SFARI']*w_hat[train_set$SFARI=='SFARI']
  
  
  # TRAIN MODEL
  fit = train(SFARI ~., data = train_set, method = 'glmnet', weights = w, trControl = trControl, 
              metric = 'ROC', tuneGrid = expand.grid(alpha = 0, lambda = lambda_seq))
  
  
  # PREDICT TEST SET LABELS AND CREATE PERFORMANCE METRICS
  
  # Predict labels in test set (probability of the SFARI class; 0.5 threshold)
  predictions = fit %>% predict(test_set, type = 'prob')
  preds = data.frame('ID' = rownames(test_set), 'corrected_prob' = predictions$SFARI) %>% 
          mutate(corrected_pred = corrected_prob > 0.5)
  

  # Measure performance of the model
  acc = mean(test_set$SFARI==preds$corrected_pred)
  prec = Precision(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  rec = Recall(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  F1 = F1_Score(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  pred_ROCR = prediction(preds$corrected_prob, test_set$SFARI)
  AUC = performance(pred_ROCR, measure='auc')@y.values[[1]]
  
  # Extract coefficients from features (at the CV-selected regularisation value)
  coefs = coef(fit$finalModel, fit$bestTune$lambda) %>% as.vector
  
  
  return(list('acc' = acc, 'prec' = prec, 'rec' = rec, 'F1' = F1, 'AUC' = AUC, 'preds' = preds, 'coefs'= coefs))
}


### RUN MODEL

# Parameters
p = 0.75                    # train-set proportion per repetition
n_iter = 25                 # number of train/test repetitions
seeds = 123:(123+n_iter-1)  # one seed per repetition

# So the input is the same as in 10_classification_model.html
# NOTE(review): left_join without 'by' relies on the implicit natural join
original_dataset = dataset %>% mutate(ID = rownames(.)) %>% 
                   left_join(biased_predictions %>% dplyr::select(ID, gene.score))

# Store outputs: one metric entry per repetition; 'predictions' and 'coefs'
# accumulate sums that are averaged after the loop
acc = c()
prec = c()
rec = c()
F1 = c()
AUC = c()
predictions = data.frame('ID' = rownames(dataset), 'SFARI' = dataset$SFARI, 'corrected_prob' = 0, 
                         'corrected_pred' = 0, 'n' = 0)
coefs = data.frame('var' = c('Intercept', colnames(dataset[,-ncol(dataset)])), 'coef' = 0)

for(seed in seeds){
  
  # Run model
  model_output = run_final_model(p, seed, lambda)
  
  # Update outputs
  acc = c(acc, model_output[['acc']])
  prec = c(prec, model_output[['prec']])
  rec = c(rec, model_output[['rec']])
  F1 = c(F1, model_output[['F1']])
  AUC = c(AUC, model_output[['AUC']])
  preds = model_output[['preds']]
  coefs$coef = coefs$coef + model_output[['coefs']] # summed here, averaged below
  # Accumulate probability sum, positive-prediction count, and appearance count
  # for the genes that fell in this repetition's test set
  update_preds = preds %>% dplyr::select(-ID) %>% mutate(n=1)
  predictions[predictions$ID %in% preds$ID, c('corrected_prob','corrected_pred','n')] = 
    predictions[predictions$ID %in% preds$ID, c('corrected_prob','corrected_pred','n')] +
     update_preds
}

# Average the accumulated sums; a gene's final class is mean probability > 0.5,
# and the positive-prediction count is kept as corrected_pred_count
coefs = coefs %>% mutate(coef = coef/n_iter)
predictions = predictions %>% mutate(corrected_prob = corrected_prob/n, corrected_pred_count = corrected_pred, 
                                     corrected_pred = corrected_prob>0.5) %>% 
              left_join(biased_predictions %>% dplyr::select(ID, prob, pred), by = 'ID')


rm(p, seeds, update_preds, create_train_test_sets, run_final_model)
# Test set = every gene that appeared in at least one test split (n > 0)
test_set = predictions %>% filter(n>0) %>% 
           left_join(dataset %>% mutate(ID = rownames(.)) %>% dplyr::select(ID, GS, MTcor), by = 'ID')
rownames(test_set) = predictions$ID[predictions$n>0]


Results


The relation between the model probability and the mean level of expression of the genes is not completely gone, there seems to be a negative relation for the genes with the lowest levels of expression

Even though the trend line is not as flat as with the first method, we are not fixing this directly as we were doing before, this is now just a consequence of the corrections we did inside of the model, so it makes sense for it to be less exact than before

# Plot results: corrected probability against each gene's mean expression,
# with a GAM trend line to show any remaining bias
expr_by_gene = data.frame('ID' = rownames(datExpr), 'meanExpr' = rowMeans(datExpr))
plot_data = right_join(expr_by_gene, test_set, by = 'ID')

ggplot(plot_data, aes(meanExpr, corrected_prob)) +
  geom_point(alpha = 0.2, color = '#0099cc') +
  geom_smooth(method = 'gam', color = 'gray', alpha = 0.2) +
  xlab('Mean Expression') +
  ylab('Corrected Probability') +
  ggtitle('Mean expression vs Model Probability corrected using adjusted Weights') +
  theme_minimal()



Performance metrics


Confusion matrix

# Attach display labels to the test set, then cross-tabulate the actual SFARI
# labels against the corrected class predictions
conf_mat = apply_labels(test_set,
                        SFARI = 'Actual Labels',
                        corrected_prob = 'Assigned Probability',
                        corrected_pred = 'Label Prediction')

cro(conf_mat$SFARI, list(conf_mat$corrected_pred, total()))
 Label Prediction     #Total 
 FALSE   TRUE   
 Actual Labels 
   FALSE  11588 967   12555
   TRUE  562 78   640
   #Total cases  12150 1045   13195
rm(conf_mat) # drop the labelled copy now that the table has been printed


Accuracy: Mean = 0.8554 SD = 0.0354


Precision: Mean = 0.0722 SD = 0.0184


Recall: Mean = 0.1617 SD = 0.0515


F1 score: Mean = 0.0969 SD = 0.0228


ROC Curve: Mean = 0.5636 SD = 0.0255

# ROC curve on the pooled test-set predictions (all genes with n > 0)
pred_ROCR = prediction(test_set$corrected_prob, test_set$SFARI)

roc_ROCR = performance(pred_ROCR, measure='tpr', x.measure='fpr')
auc = performance(pred_ROCR, measure='auc')@y.values[[1]]

# NOTE(review): the title reports mean(AUC) across the 25 iterations, not the
# pooled 'auc' computed above — the two values can differ slightly
plot(roc_ROCR, main=paste0('ROC curve (AUC=',round(mean(AUC),2),')'), col='#009999')
abline(a=0, b=1, col='#666666')


Lift Curve

# Lift curve: lift over the rate of positive predictions
lift_ROCR = performance(pred_ROCR, measure='lift', x.measure='rpp')
plot(lift_ROCR, main='Lift curve', col='#86b300')

# Clean up the performance objects and the per-iteration metric vectors
rm(pred_ROCR, roc_ROCR, AUC, lift_ROCR, acc, acc_vec, auc, bias_vec, F1, prec, rec)




Coefficients


# Per-gene module assignment, module-diagnosis correlation and SFARI flag
gene_corr_info = dataset %>% mutate('ID' = rownames(dataset)) %>% dplyr::select(ID, MTcor, SFARI) %>% 
                 left_join(assigned_module, by ='ID') %>% mutate(Module = gsub('#','',Module))

# Join each coefficient (module-membership features are named 'MM.<module>')
# with its module's MTcor and the fraction of SFARI genes assigned to it
coef_info = coefs %>% mutate('feature' = gsub('MM.','',var)) %>% 
            left_join(gene_corr_info, by = c('feature' = 'Module')) %>% 
            dplyr::select(feature, coef, MTcor, SFARI) %>% group_by(feature, coef, MTcor) %>% 
            summarise('SFARI_perc' = mean(SFARI)) %>% arrange(desc(coef))

# Table of the non-module coefficients only
coef_info %>% dplyr::select(feature, coef) %>% filter(feature %in% c('Intercept','GS','absGS','MTcor')) %>%
              dplyr::rename('Feature' = feature, 'Coefficient' = coef) %>% 
              kable(align = 'cc', caption = 'Regression Coefficients') %>% kable_styling(full_width = F)
Regression Coefficients
Feature Coefficient
MTcor 0.0231677
GS -0.0201093
absGS -0.0209015
Intercept -0.4946006


There is still a positive relation between the coefficient assigned to the membership of each module and the enrichment (using ORA) in SFARI genes that are assigned to that module

# Loads the over-representation analysis results (enrichment_SFARI and others)
load('./../Data/ORA.RData')

# Build a per-module SFARI enrichment score: 1 - p-value when the module's ORA
# results contain a SFARI term, 0 otherwise
enrichment_SFARI_info = data.frame('Module'=as.character(), 'SFARI_enrichment'=as.numeric())
for(m in names(enrichment_SFARI)){
  m_info = enrichment_SFARI[[m]]
  enrichment = 1-ifelse('SFARI' %in% m_info$ID, m_info$pvalue[m_info$ID=='SFARI'],1)
  enrichment_SFARI_info = enrichment_SFARI_info %>% 
                          add_row(Module = gsub('#','',m), SFARI_enrichment = enrichment)
}

# Coefficient vs SFARI enrichment, one point per module, coloured with the
# module's own hex colour code
plot_data = coef_info %>% dplyr::rename('Module' = feature) %>% 
            left_join(enrichment_SFARI_info, by = 'Module') %>% filter(!is.na(MTcor))

ggplotly(plot_data %>% ggplot(aes(coef, SFARI_enrichment)) + 
         geom_smooth(method = 'lm', color = 'gray', alpha = 0.1) + 
         geom_point(aes(id = Module), color = paste0('#',plot_data$Module), alpha=0.7) + 
         theme_minimal() + xlab('Coefficient') + 
         ylab('SFARI Genes Enrichment'))
# Drop the other enrichment objects loaded from ORA.RData plus loop leftovers
rm(enrichment_old_SFARI, enrichment_DGN, enrichment_DO, enrichment_GO, enrichment_KEGG, enrichment_Reactome, m,
   m_info, enrichment)


# Coefficient vs module-diagnosis correlation, coloured per module
# NOTE(review): the point colours index coef_info$feature filtered on
# !is.na(MTcor), which assumes the same row order as the plotted filtered data
ggplotly(coef_info %>% dplyr::rename('Module' = feature) %>% filter(!is.na(MTcor)) %>%
         ggplot(aes(coef, MTcor)) +  geom_smooth(method = 'lm', color = 'gray', alpha = 0.1) + 
         geom_point(aes(id = Module), color=paste0('#',coef_info$feature[!is.na(coef_info$MTcor)]), alpha=.7) + 
         theme_minimal() + xlab('Coefficient') + ylab('Module-Diagnosis correlation'))




Analyse Results


Probability distribution by SFARI Label


SFARI genes have a higher Probability distribution than the rest, but the overlap is larger than before

# Density of corrected probabilities split by SFARI label, with a dashed
# vertical line at each group's mean probability
plot_data = dplyr::select(test_set, corrected_prob, SFARI)

ggplotly(
  ggplot(plot_data, aes(corrected_prob, fill = SFARI, color = SFARI)) +
    geom_density(alpha = 0.3) +
    geom_vline(xintercept = mean(plot_data$corrected_prob[plot_data$SFARI]),
               color = '#00C0C2', linetype = 'dashed') +
    geom_vline(xintercept = mean(plot_data$corrected_prob[!plot_data$SFARI]),
               color = '#FF7371', linetype = 'dashed') +
    xlab('Score') +
    ggtitle('Model Probability distribution by SFARI Label') +
    theme_minimal()
)


Probability distribution by SFARI Gene Scores


The relation between probability and SFARI Gene Scores weakened but it’s still there

# Corrected probability joined with the SFARI gene score of each test-set gene
plot_data = test_set %>% mutate(ID=rownames(test_set)) %>% dplyr::select(ID, corrected_prob) %>%
            left_join(original_dataset, by='ID') %>% dplyr::select(ID, corrected_prob, gene.score) %>% 
            apply_labels(gene.score='SFARI Gene score')

# Frequency table: number of genes per SFARI score group
cro(plot_data$gene.score)
 #Total 
 SFARI Gene score 
   1  105
   2  168
   3  364
   Neuronal  782
   Others  11762
   #Total cases  13181
mean_vals = plot_data %>% group_by(gene.score) %>% summarise(mean_prob = mean(corrected_prob))

# Pairwise group comparisons to annotate on the boxplots
comparisons = list(c('1','2'), c('2','3'), c('3','Neuronal'), c('Neuronal','Others'),
                   c('1','3'), c('3','Others'), c('2','Neuronal'),
                   c('1','Neuronal'), c('2','Others'), c('1','Others'))
increase = 0.07 # vertical spacing between significance brackets
base = 0.75     # height of the lowest bracket row
pos_y_comparisons = c(rep(base, 4), rep(base + increase, 2), base + 2:5*increase)

# Boxplots of corrected probability by SFARI score, with Welch t-test
# significance labels for each comparison
plot_data %>% filter(!is.na(gene.score)) %>% ggplot(aes(gene.score, corrected_prob, fill=gene.score)) + 
              geom_boxplot(outlier.colour='#cccccc', outlier.shape='o', outlier.size=3) +
              stat_compare_means(comparisons = comparisons, label = 'p.signif', method = 't.test', 
                                 method.args = list(var.equal = FALSE), label.y = pos_y_comparisons, 
                                 tip.length = .02) +
              scale_fill_manual(values=SFARI_colour_hue(r=c(1:3,8,7))) + 
              ggtitle('Distribution of probabilities by SFARI score') +
              xlab('SFARI score') + ylab('Probability') + theme_minimal() + theme(legend.position = 'none')

rm(mean_vals, increase, base, pos_y_comparisons)


Genes with the highest Probabilities


  • The concentration of SFARI genes decrease from 1:4 to 1:12

  • The genes with the highest probabilities are no longer SFARI Genes

# Table of the 50 genes with the highest corrected probabilities, annotated
# with gene symbol, SFARI score, module and module-diagnosis correlation
test_set %>% dplyr::select(corrected_prob, SFARI) %>% mutate(ID = rownames(test_set)) %>% 
             arrange(desc(corrected_prob)) %>% top_n(50, wt=corrected_prob) %>%
             left_join(biased_predictions %>% dplyr::select(ID, gene.score, external_gene_id, MTcor, GS), 
                       by = 'ID') %>%
             dplyr::rename('GeneSymbol' = external_gene_id, 'Probability' = corrected_prob, 
                           'ModuleDiagnosis_corr' = MTcor, 'GeneSignificance' = GS) %>%
             mutate(ModuleDiagnosis_corr = round(ModuleDiagnosis_corr,4), Probability = round(Probability,4), 
                    GeneSignificance = round(GeneSignificance,4)) %>%
             left_join(assigned_module, by = 'ID') %>%
             dplyr::select(GeneSymbol, GeneSignificance, ModuleDiagnosis_corr, Module, Probability,
                           gene.score) %>%
             kable(caption = 'Genes with highest model probabilities from the test set') %>% 
             kable_styling(full_width = F)
Genes with highest model probabilities from the test set
GeneSymbol GeneSignificance ModuleDiagnosis_corr Module Probability gene.score
TRPC6 0.1888 0.4910 #E28900 0.6947 2
PLXDC2 0.3497 0.5272 #F564E4 0.6781 Others
DIP2C -0.1778 0.2573 #FA7377 0.6654 2
MB21D2 -0.0983 -0.0793 #63B200 0.6601 Others
CPT1C -0.0969 -0.0102 #00BDD3 0.6448 Others
PCDH15 0.1827 0.3849 #C89800 0.6397 3
TMEM184A 0.1764 0.0319 #00B6EA 0.6371 Others
CMTM7 -0.1504 -0.0683 #FF6B93 0.6338 Others
FBXL13 0.1544 0.5621 #4BB400 0.6325 Others
GABRA5 0.5898 0.4329 #00C0BF 0.6319 Neuronal
OCA2 -0.2946 -0.3932 #FE6E8A 0.6314 Others
RPTOR 0.3295 0.1819 #DC8D00 0.6299 Others
PRKCE -0.0371 0.4910 #E28900 0.6288 Others
CSF2RA -0.3589 -0.4038 #00B4EF 0.6285 Others
RPRD2 0.1194 0.3738 #FC7181 0.6264 Others
FA2H -0.2586 -0.2919 #00BFC3 0.6263 Others
OR51I1 0.1205 0.4910 #E28900 0.6256 Others
TMEM211 -0.4070 -0.4513 #FF699C 0.6249 Others
CACNA1D -0.4023 -0.2919 #00BFC3 0.6226 2
PLXNC1 0.5660 0.5125 #89AC00 0.6217 Others
TUSC3 0.1125 -0.0102 #00BDD3 0.6217 Others
MYRF -0.3458 -0.2919 #00BFC3 0.6214 Others
GABBR2 0.1204 0.2418 #A68AFF 0.6187 3
IQCF3 0.4313 0.6982 #FF68A0 0.6167 Others
CACNA1C -0.1527 0.2418 #A68AFF 0.6164 1
IL4R -0.1025 -0.1266 #00BA3B 0.6139 Others
IFI16 -0.4044 -0.6465 #00C0B8 0.6138 Others
ANKRD2 0.0212 -0.0683 #FF6B93 0.6109 Others
ACAP1 -0.3687 -0.1564 #C39A00 0.6082 Others
OPRM1 0.3906 0.4910 #E28900 0.6081 Others
MX2 -0.3168 -0.0728 #DE8C00 0.6074 Others
KIAA1644 -0.0078 -0.1197 #9EA700 0.6071 Others
WIPF1 -0.1209 0.0860 #96A900 0.6069 Others
FAM98B -0.2066 -0.3187 #D39300 0.6069 Others
CD86 -0.4779 -0.6465 #00C0B8 0.6068 Others
CHST8 -0.0678 -0.5295 #C69900 0.6056 Others
C14orf39 -0.1211 0.0270 #00B7E7 0.6052 Others
GRIA1 0.2054 0.4910 #E28900 0.6047 2
IL1B -0.6507 -0.8841 #E76CF3 0.6042 Others
LRRC16B 0.0365 -0.0793 #63B200 0.6032 Others
CREBBP -0.0191 0.0401 #00BF7F 0.6025 1
POLR1E -0.3134 -0.4513 #FF699C 0.6024 Others
C1orf95 -0.2008 0.1945 #D59100 0.6024 Others
SHISA9 -0.0285 0.2203 #F17D50 0.6020 Neuronal
RAE1 0.3609 0.4329 #00C0BF 0.6013 Others
CECR1 0.2609 0.2029 #7AAE00 0.6013 Others
CHST1 0.2286 0.4910 #E28900 0.6005 Others
PLXNA1 0.2434 0.2418 #A68AFF 0.6005 Others
TP53I11 -0.1272 0.5125 #89AC00 0.6001 Others
PRELP 0.2110 0.5125 #89AC00 0.6000 Others





Negative samples distribution


The objective of this model is to identify candidate SFARI genes. For this, we are going to focus on the negative samples (the non-SFARI genes)

# Keep only the negative (non-SFARI) genes: the candidate pool for new predictions
negative_set = test_set %>% filter(!SFARI)

negative_set_table = negative_set %>% apply_labels(corrected_prob = 'Assigned Probability', 
                                                   corrected_pred = 'Label Prediction')

# Count how many negatives the model predicts as ASD-related
cro(negative_set_table$corrected_pred)
 #Total 
 Label Prediction 
   FALSE  11588
   TRUE  967
   #Total cases  12555

967 genes are predicted as ASD-related


# Density of corrected probabilities for the negative genes, with a dotted
# line at the 0.5 classification threshold
ggplot(negative_set, aes(corrected_prob)) +
  geom_density(color = '#F8766D', fill = '#F8766D', alpha = 0.5) +
  geom_vline(xintercept = 0.5, color = '#333333', linetype = 'dotted') +
  xlab('Probability') +
  ggtitle('Probability distribution of the Negative samples in the Test Set') +
  theme_minimal()


# Table of the 50 candidate (non-SFARI) genes with the highest corrected
# probabilities, with the same annotations as the test-set table above
negative_set %>% dplyr::select(corrected_prob, SFARI) %>% mutate(ID = rownames(negative_set)) %>% 
                 arrange(desc(corrected_prob)) %>% top_n(50, wt=corrected_prob) %>%
                 left_join(biased_predictions %>% dplyr::select(ID, gene.score, external_gene_id, MTcor, GS), 
                           by = 'ID') %>%
                 dplyr::rename('GeneSymbol' = external_gene_id, 'Probability' = corrected_prob, 
                               'ModuleDiagnosis_corr' = MTcor, 'GeneSignificance' = GS) %>%
                 mutate(ModuleDiagnosis_corr = round(ModuleDiagnosis_corr,4), 
                        Probability = round(Probability,4), 
                        GeneSignificance = round(GeneSignificance,4)) %>%
                 left_join(assigned_module, by = 'ID') %>%
                 dplyr::select(GeneSymbol, GeneSignificance, ModuleDiagnosis_corr, Module, Probability,
                               gene.score) %>%
                 kable(caption = 'Genes with highest model probabilities from the Negative set') %>% 
                 kable_styling(full_width = F)
Genes with highest model probabilities from the Negative set
GeneSymbol GeneSignificance ModuleDiagnosis_corr Module Probability gene.score
PLXDC2 0.3497 0.5272 #F564E4 0.6781 Others
MB21D2 -0.0983 -0.0793 #63B200 0.6601 Others
CPT1C -0.0969 -0.0102 #00BDD3 0.6448 Others
TMEM184A 0.1764 0.0319 #00B6EA 0.6371 Others
CMTM7 -0.1504 -0.0683 #FF6B93 0.6338 Others
FBXL13 0.1544 0.5621 #4BB400 0.6325 Others
GABRA5 0.5898 0.4329 #00C0BF 0.6319 Neuronal
OCA2 -0.2946 -0.3932 #FE6E8A 0.6314 Others
RPTOR 0.3295 0.1819 #DC8D00 0.6299 Others
PRKCE -0.0371 0.4910 #E28900 0.6288 Others
CSF2RA -0.3589 -0.4038 #00B4EF 0.6285 Others
RPRD2 0.1194 0.3738 #FC7181 0.6264 Others
FA2H -0.2586 -0.2919 #00BFC3 0.6263 Others
OR51I1 0.1205 0.4910 #E28900 0.6256 Others
TMEM211 -0.4070 -0.4513 #FF699C 0.6249 Others
PLXNC1 0.5660 0.5125 #89AC00 0.6217 Others
TUSC3 0.1125 -0.0102 #00BDD3 0.6217 Others
MYRF -0.3458 -0.2919 #00BFC3 0.6214 Others
IQCF3 0.4313 0.6982 #FF68A0 0.6167 Others
IL4R -0.1025 -0.1266 #00BA3B 0.6139 Others
IFI16 -0.4044 -0.6465 #00C0B8 0.6138 Others
ANKRD2 0.0212 -0.0683 #FF6B93 0.6109 Others
ACAP1 -0.3687 -0.1564 #C39A00 0.6082 Others
OPRM1 0.3906 0.4910 #E28900 0.6081 Others
MX2 -0.3168 -0.0728 #DE8C00 0.6074 Others
KIAA1644 -0.0078 -0.1197 #9EA700 0.6071 Others
WIPF1 -0.1209 0.0860 #96A900 0.6069 Others
FAM98B -0.2066 -0.3187 #D39300 0.6069 Others
CD86 -0.4779 -0.6465 #00C0B8 0.6068 Others
CHST8 -0.0678 -0.5295 #C69900 0.6056 Others
C14orf39 -0.1211 0.0270 #00B7E7 0.6052 Others
IL1B -0.6507 -0.8841 #E76CF3 0.6042 Others
LRRC16B 0.0365 -0.0793 #63B200 0.6032 Others
POLR1E -0.3134 -0.4513 #FF699C 0.6024 Others
C1orf95 -0.2008 0.1945 #D59100 0.6024 Others
SHISA9 -0.0285 0.2203 #F17D50 0.6020 Neuronal
RAE1 0.3609 0.4329 #00C0BF 0.6013 Others
CECR1 0.2609 0.2029 #7AAE00 0.6013 Others
CHST1 0.2286 0.4910 #E28900 0.6005 Others
PLXNA1 0.2434 0.2418 #A68AFF 0.6005 Others
TP53I11 -0.1272 0.5125 #89AC00 0.6001 Others
PRELP 0.2110 0.5125 #89AC00 0.6000 Others
HHLA1 0.2522 -0.0470 #00C1A1 0.5999 Others
LHFPL2 -0.1698 -0.0102 #00BDD3 0.5995 Others
PAN3 -0.2159 -0.5696 #FD6F86 0.5991 Others
SEC16A 0.0131 0.3738 #FC7181 0.5989 Others
DLG5 0.2711 0.9122 #ADA200 0.5987 Others
PXN -0.0361 0.0961 #FD61D1 0.5982 Others
SUV39H1 -0.2674 -0.4513 #FF699C 0.5978 Others
YLPM1 0.1645 -0.0235 #FF62BC 0.5973 Others




Comparison with the original model’s probabilities:

  • The genes with the highest Probabilities were affected the most as a group

  • In general genes with the lowest Probabilities got their score increased and the genes with the highest scores, decreased

  • The change in Probability by gene is much larger than with the Post Processing approach

# Original vs corrected probability for each negative gene, coloured by the
# absolute size of the change; the dashed diagonal marks "no change"
negative_set %>% mutate(diff = abs(prob-corrected_prob)) %>% 
             ggplot(aes(prob, corrected_prob, color = diff)) + geom_point(alpha=0.3) + scale_color_viridis() + 
             geom_abline(slope=1, intercept=0, color='gray', linetype='dashed') + 
             geom_smooth(color='#666666', alpha=0.5, se=TRUE, size=0.5) + coord_fixed() +
             xlab('Original probability') + ylab('Corrected probability') + theme_minimal() + theme(legend.position = 'none')

# Cross-tabulate original vs corrected class predictions, restricted to the
# genes that have an original prediction
negative_set_table = negative_set %>% apply_labels(corrected_prob = 'Corrected Probability', 
                                                   corrected_pred = 'Corrected Class Prediction',
                                                   pred = 'Original Class Prediction') %>%
                     filter(!is.na(pred))

cro(negative_set_table$pred, list(negative_set_table$corrected_pred, total()))
 Corrected Class Prediction     #Total 
 FALSE   TRUE   
 Original Class Prediction 
   FALSE  9809 286   10095
   TRUE  1770 679   2449
   #Total cases  11579 965   12544

84% of the genes maintained their original predicted class

rm(negative_set_table) # drop the labelled copy now that the table has been printed

Probability and Gene Significance


The relation is the opposite as before, the higher the Gene Significance, the lower the probability, with the highest probabilities corresponding to under-expressed genes

*The transparent version of the trend line is the original trend line

# Corrected probability vs Gene Significance, coloured by module-diagnosis
# correlation; the faint line repeats the trend of the original (uncorrected)
# probabilities for comparison
negative_set %>% ggplot(aes(corrected_prob, GS, color=MTcor)) + geom_point() + 
                 geom_smooth(method='gam', color='#666666') + ylab('Gene Significance') +
                 geom_line(stat='smooth', method='gam', color='#666666', alpha=0.5, size=1.2, aes(x=prob)) +
                 geom_hline(yintercept=mean(negative_set$GS), color='gray', linetype='dashed') +
                 scale_color_gradientn(colours=c('#F8766D','white','#00BFC4')) + xlab('Corrected Probability') +
                 ggtitle('Relation between the Model\'s Corrected Probability and Gene Significance') + 
                 theme_minimal()

Summarised version of Probability vs mean expression, plotting by module instead of by gene

The difference in the trend lines between this plot and the one above is that the one above takes all the points into consideration while this considers each module as an observation by itself, so the top one is strongly affected by big modules and the bottom one treats all modules the same

The transparent version of each point and trend lines are the original values and trends before the bias correction

# Per-module summary of the original (mean, sd) and bias-corrected
# (new_mean, new_sd) probabilities, with module size (n) and the sign of the
# module-diagnosis correlation.
# Fixes: ungroup() added so no grouping lingers after summarise(), and the
# fragile positional rename colnames(plot_data)[1] = 'ID' is replaced with an
# explicit dplyr::rename (the first column must be called ID for the plot's
# aes(id = ID) below)
plot_data = negative_set %>% mutate(ID = rownames(.)) %>% left_join(assigned_module, by = 'ID') %>%
            group_by(MTcor, Module) %>% summarise(mean = mean(prob), sd = sd(prob),
                                                  new_mean = mean(corrected_prob),
                                                  new_sd = sd(corrected_prob), n = n()) %>%
            ungroup() %>%
            mutate(MTcor_sign = ifelse(MTcor>0, 'Positive', 'Negative')) %>% 
            dplyr::select(Module, MTcor, MTcor_sign, mean, new_mean, sd, new_sd, n) %>% distinct() %>%
            dplyr::rename(ID = Module)

# Interactive plot: mean probability per module (y) vs module-diagnosis
# correlation (x), point size = module size. Solid points and trend lines use
# the corrected probabilities; the transparent ones show the original values
# before the bias correction
ggplotly(plot_data %>% ggplot(aes(MTcor, new_mean, size=n, color=MTcor_sign)) + geom_point(aes(id = ID)) + 
         geom_smooth(method='loess', color='gray', se=FALSE) + geom_smooth(method='lm', se=FALSE) + 
         geom_point(aes(y=mean), alpha=0.3) + 
         geom_line(stat='smooth', method='loess', color='gray', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) + 
         geom_line(stat='smooth', method='lm', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) + 
         xlab('Module-Diagnosis correlation') + ylab('Mean Corrected Probability by Module') + 
         theme_minimal() + theme(legend.position='none'))


Probability and mean level of expression


To check if correcting by gene also corrected by module: Yes, the bias seems to be removed completely

# Per-gene mean and standard deviation of expression
mean_and_sd <- data.frame(
  ID = rownames(datExpr),
  meanExpr = rowMeans(datExpr),
  sdExpr = apply(datExpr, 1, sd)
)

# Attach expression summaries and module assignment to the negative set
# (assumes negative_set rows are aligned with the non-SFARI rows of test_set
# — TODO confirm against where negative_set is built)
plot_data <- negative_set %>%
  mutate(ID = rownames(test_set)[!test_set$SFARI]) %>%
  left_join(mean_and_sd, by = 'ID') %>%
  left_join(assigned_module, by = 'ID')

# Collapse to one observation per module
plot_data2 <- plot_data %>%
  group_by(Module) %>%
  summarise(meanExpr = mean(meanExpr),
            meanProb = mean(prob),
            new_meanProb = mean(corrected_prob),
            n = n())

# Interactive plot: mean expression vs mean probability by module, point size
# = module size, points colored with each module's own color code. Solid
# elements use the corrected probabilities; transparent ones the originals
ggplotly(plot_data2 %>% ggplot(aes(meanExpr, new_meanProb, size=n)) + 
         geom_point(color=plot_data2$Module) + geom_point(color=plot_data2$Module, alpha=0.3, aes(y=meanProb)) + 
         geom_smooth(method='loess', se=TRUE, color='gray', alpha=0.1, size=0.7) + 
         geom_line(stat='smooth', method='loess', se=TRUE, color='gray', alpha=0.4, size=1.2, aes(y=meanProb)) +
         xlab('Mean Expression') + ylab('Corrected Probability') +  
         ggtitle('Mean expression vs corrected Model Probability by Module') +
         theme_minimal() + theme(legend.position='none'))
# Clean up chunk-local helpers
rm(plot_data2, mean_and_sd)


Probability and LFC


Under-expressed genes got their probabilities increased and over-expressed genes got theirs decreased

# Join differential-expression results (logFC, adjusted p-value) onto the
# negative set, renaming to the column names the plot below expects
# (assumes negative_set rows are aligned with the non-SFARI rows of test_set
# — TODO confirm against where negative_set is built)
plot_data = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>% 
            left_join(DE_info %>% data.frame %>% mutate(ID=rownames(.)), by='ID') %>%
            dplyr::rename('log2FoldChange' = logFC, 'padj' = adj.P.Val)

# LFC vs corrected probability; the transparent trend line is the original
# (pre-correction) trend for comparison
plot_data %>% ggplot(aes(log2FoldChange, corrected_prob)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='loess', color='gray', alpha=0.1) + 
              geom_line(stat='smooth', method='loess', color='gray', alpha=0.4, size=1.5, aes(y=prob)) +
              xlab('LFC') + ylab('Corrected Probability') +
              theme_minimal() + ggtitle('LFC vs model probability by gene')


Probability and Module-Diagnosis correlation


The probabilities increased for modules with negative correlation and decreased for modules with positive correlation

# Build a per-gene table of original and corrected probabilities together with
# the gene's module and the module-diagnosis correlation, plus the rank
# ('order') of each module when sorted by MTcor.
# Fixes: arrange(by=MTcor) passed a misleadingly named expression into
# arrange()'s `...` (the name was ignored) -> arrange(MTcor); the rank is now
# computed with row_number() instead of the 1:length(...) anti-pattern.
# NOTE(review): the gene.score column joined from biased_predictions is
# dropped by the dplyr::select below; kept as-is to preserve behavior, but it
# looks like dead code — confirm before removing
module_score = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>%
               left_join(biased_predictions %>% dplyr::select(ID, gene.score), by='ID') %>%
               left_join(assigned_module, by = 'ID') %>%
               dplyr::select(ID, prob, corrected_prob, Module, MTcor) %>% 
               left_join(data.frame(MTcor = unique(dataset$MTcor)) %>% arrange(MTcor) %>% 
                         mutate(order = row_number()), by='MTcor')

# Interactive plot: corrected probability vs module-diagnosis correlation,
# points colored with each module's own color; the transparent trend line is
# the original (pre-correction) trend, the dotted line the overall mean
ggplotly(module_score %>% ggplot(aes(MTcor, corrected_prob)) + 
         geom_point(color=module_score$Module, aes(id=ID, alpha=corrected_prob^4)) +
         geom_hline(yintercept=mean(module_score$corrected_prob), color='gray', linetype='dotted') + 
         geom_line(stat='smooth', method = 'loess', color='gray', alpha=0.5, size=1.5, aes(x=MTcor, y=prob)) +
         geom_smooth(color='gray', method = 'loess', se = FALSE, alpha=0.3) + theme_minimal() + 
         xlab('Module-Diagnosis correlation') + ylab('Corrected Probability'))



Conclusion


This bias correction makes bigger changes in the distribution of the probabilities than the post-processing one

In general, the performance metrics decrease, but this isn’t necessarily bad, since we knew part of the good performance of the model was because of the confounding factor related to mean level of expression, so it was expected for the performance of the model to decrease once we removed this signal


Saving results

write.csv(test_set, file='./../Data/RM_weighting_bias_correction.csv', row.names = TRUE)




Session info

sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 18.04.4 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
## 
## locale:
##  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
##  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
##  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] grid      stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] expss_0.10.2       kableExtra_1.1.0   MLmetrics_1.1.1    car_3.0-7         
##  [5] carData_3.0-3      ROCR_1.0-7         gplots_3.0.3       DMwR_0.4.1        
##  [9] caret_6.0-86       lattice_0.20-41    mgcv_1.8-31        nlme_3.1-147      
## [13] reshape2_1.4.4     ggpubr_0.2.5       magrittr_1.5       RColorBrewer_1.1-2
## [17] gridExtra_2.3      viridis_0.5.1      viridisLite_0.3.0  plotly_4.9.2      
## [21] knitr_1.28         forcats_0.5.0      stringr_1.4.0      dplyr_1.0.0       
## [25] purrr_0.3.4        readr_1.3.1        tidyr_1.1.0        tibble_3.0.1      
## [29] ggplot2_3.3.2      tidyverse_1.3.0   
## 
## loaded via a namespace (and not attached):
##  [1] colorspace_1.4-1     ggsignif_0.6.0       ellipsis_0.3.1      
##  [4] class_7.3-17         rio_0.5.16           htmlTable_1.13.3    
##  [7] fs_1.4.0             rstudioapi_0.11      farver_2.0.3        
## [10] prodlim_2019.11.13   fansi_0.4.1          lubridate_1.7.4     
## [13] xml2_1.2.5           codetools_0.2-16     splines_3.6.3       
## [16] jsonlite_1.7.0       pROC_1.16.2          broom_0.5.5         
## [19] dbplyr_1.4.2         compiler_3.6.3       httr_1.4.1          
## [22] backports_1.1.8      assertthat_0.2.1     Matrix_1.2-18       
## [25] lazyeval_0.2.2       cli_2.0.2            htmltools_0.4.0     
## [28] tools_3.6.3          gtable_0.3.0         glue_1.4.1          
## [31] Rcpp_1.0.4.6         cellranger_1.1.0     vctrs_0.3.1         
## [34] gdata_2.18.0         crosstalk_1.1.0.1    iterators_1.0.12    
## [37] timeDate_3043.102    gower_0.2.1          xfun_0.12           
## [40] openxlsx_4.1.4       rvest_0.3.5          lifecycle_0.2.0     
## [43] gtools_3.8.2         MASS_7.3-51.6        zoo_1.8-8           
## [46] scales_1.1.1         ipred_0.9-9          hms_0.5.3           
## [49] yaml_2.2.1           quantmod_0.4.17      curl_4.3            
## [52] rpart_4.1-15         stringi_1.4.6        highr_0.8           
## [55] foreach_1.5.0        checkmate_2.0.0      TTR_0.23-6          
## [58] caTools_1.18.0       zip_2.0.4            shape_1.4.4         
## [61] lava_1.6.7           matrixStats_0.56.0   rlang_0.4.6         
## [64] pkgconfig_2.0.3      bitops_1.0-6         evaluate_0.14       
## [67] labeling_0.3         recipes_0.1.10       htmlwidgets_1.5.1   
## [70] tidyselect_1.1.0     plyr_1.8.6           R6_2.4.1            
## [73] generics_0.0.2       DBI_1.1.0            pillar_1.4.4        
## [76] haven_2.2.0          foreign_0.8-76       withr_2.2.0         
## [79] xts_0.12-0           survival_3.1-12      abind_1.4-5         
## [82] nnet_7.3-14          modelr_0.1.6         crayon_1.3.4        
## [85] KernSmooth_2.23-17   rmarkdown_2.1        readxl_1.3.1        
## [88] data.table_1.12.8    ModelMetrics_1.2.2.2 webshot_0.5.2       
## [91] reprex_0.3.0         digest_0.6.25        glmnet_3.0-2        
## [94] stats4_3.6.3         munsell_0.5.0